
Fix/pyt hy video amd dockerfile #149

Open

vadseshu wants to merge 2 commits into ROCm:develop from vadseshu:fix/pyt-hy-video-amd-dockerfile

Conversation

@vadseshu
Contributor

Motivation

This PR fixes the flash-attention compilation; the steps follow the AMD ROCm support instructions at https://github.com/Dao-AILab/flash-attention.

Technical Details

The pyt_hy_video AMD image failed in two places:

  • Build: rocmProfileData’s rocpd_python Makefile runs pip install --user ., which is invalid inside the base image’s Python venv (/opt/venv), so make install exited with a pip error.
  • Runtime: pip resolved transformers 5.x as a dependency of distvae, while diffusers==0.32.2 still imports FLAX_WEIGHTS_NAME from transformers.utils, which is no longer exposed in v5. That caused an ImportError when loading diffusers pipelines (e.g. hunyuan_video_usp_example.py under torchrun); a quick check is sketched below.
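
A minimal way to confirm the incompatibility inside the image (a sketch; this exact check is not part of the PR):

# Hypothetical check, not taken from the PR: exercises the import that diffusers 0.32.2 performs.
# It raises ImportError on transformers 5.x and succeeds on a 4.x release.
python3 -c "from transformers.utils import FLAX_WEIGHTS_NAME; print(FLAX_WEIGHTS_NAME)"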

  • Constrain transformers to >=4.44,<5 on the same install line as diffusers / distvae so the resolver stays on a 4.x release compatible with diffusers 0.32.2.
  • After cloning rocmProfileData, patch rocpd_python/Makefile to replace pip install --user with pip install before running make install, so installs target the venv.
  • Flash-attention: switch to the ROCm Triton path (FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + pip install --no-build-isolation .) so the image can be built without a GPU-arch-scoped wheel build in headless CI; the previous pinned SHA / bdist_wheel path is left commented for reference. All three changes are sketched together below.
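
A rough Dockerfile sketch of the three changes (repository URLs, package lines, and the sed expression are illustrative and may not match the actual diff):

# Sketch only: illustrates the fixes described above; the exact lines in
# docker/pyt_hy_video.ubuntu.amd.Dockerfile may differ.

# 1. Keep transformers on 4.x so diffusers 0.32.2 can still import FLAX_WEIGHTS_NAME.
RUN pip install "diffusers==0.32.2" "transformers>=4.44,<5" distvae

# 2. Patch rocmProfileData so rocpd_python installs into the venv instead of --user.
RUN git clone https://github.com/ROCm/rocmProfileData.git \
    && cd rocmProfileData \
    && sed -i 's/pip install --user/pip install/' rocpd_python/Makefile \
    && make install

# 3. Build flash-attention through the ROCm Triton path (no GPU needed at build time).
RUN git clone https://github.com/ROCm/flash-attention.git \
    && cd flash-attention \
    && FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .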

Test Plan

Test Result

[test result screenshot attached in the PR]

Submission Checklist

vadseshu1 and others added 2 commits April 16, 2026 13:10
Pin transformers to 4.x so diffusers 0.32.2 can import FLAX_WEIGHTS_NAME.
Patch rocmProfileData rocpd_python Makefile to avoid pip install --user
inside the base image venv.

Flash-attention is installed with FLASH_ATTENTION_TRITON_AMD_ENABLE for
headless/CI-friendly builds; replace prior wheel build from pinned SHA.

Copilot AI left a comment


Pull request overview

This PR updates AMD PyTorch Docker images to make flash-attention build reliably in headless CI and to prevent a runtime incompatibility between diffusers==0.32.2 and transformers v5 in the pyt_hy_video image.

Changes:

  • Pin transformers to <5 alongside diffusers==0.32.2 in pyt_hy_video to avoid an ImportError at runtime.
  • Switch pyt_hy_video flash-attention installation to the ROCm Triton path and patch rocmProfileData’s rocpd_python/Makefile to avoid pip install --user inside a venv.
  • Add gfx950 to the ROCm arch list in pyt_mochi_inference.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Reviewed files:
  • docker/pyt_mochi_inference.ubuntu.amd.Dockerfile: Extends the ROCm arch list used for the flash-attention wheel build.
  • docker/pyt_hy_video.ubuntu.amd.Dockerfile: Pins transformers to 4.x, changes the flash-attention install approach, and patches rocmProfileData install behavior for venv compatibility.
Comments suppressed due to low confidence (1)

docker/pyt_mochi_inference.ubuntu.amd.Dockerfile:40

  • PYTORCH_ROCM_ARCH is a semicolon-delimited string, but it’s expanded unquoted in the later GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | ...) command substitution. In /bin/sh, the ; characters will be treated as command separators, which can break the build. Quote the expansion (e.g., use echo "${PYTORCH_ROCM_ARCH}") or switch to a delimiter that won’t be parsed by the shell and adapt the sed accordingly.
ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
RUN git clone ${FA_REPO}
RUN cd flash-attention \
    && git submodule update --init \
    && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist \
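
For reference, the quoted expansion this comment asks for might look like the following sketch (not part of the PR's diff):

# Sketch of the quoted form suggested above; behavior is otherwise unchanged.
RUN cd flash-attention \
    && git submodule update --init \
    && GPU_ARCHS=$(echo "${PYTORCH_ROCM_ARCH}" | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl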


Comment on lines +25 to +39
#ARG FA_SHA="b3ae4966b2567811880db10d9e040a775b99c7d7"
#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
#ARG FA_GPU_ARCH=gfx942
#RUN git clone ${FA_REPO} && \
# cd flash-attention && \
# git checkout ${FA_SHA} && \
# git submodule update --init && \
# F='${FA_GPU_ARCH}' && \
# if [ -z "$F" ]; then F=gfx942; fi && \
# if [ "$F" = "native" ]; then F=gfx942; fi && \
# GPU_ARCHS="$F" python3 setup.py bdist_wheel --dist-dir=dist && \
# pip install dist/*.whl;
RUN git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
Collaborator

@gargrahul Apr 24, 2026


@vadseshu let us pin FA to a working commit as suggested.

pip install dist/*.whl;
# flash attn (avoid ARG name PYTORCH_ROCM_ARCH: base image ENV can shadow it and expand to "")
# ROCm flash-attention: FA_GPU_ARCH=native needs a visible GPU at compile time and fails in CI/docker.
# Coerce native/empty to gfx942 for headless CI; for MI350 pass --build-arg FA_GPU_ARCH=gfx950 (needs FA_SHA with gfx950 in setup.py).
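
A sketch of what pinning flash-attention to a specific commit while keeping the Triton path could look like; the SHA shown simply reuses the previously pinned commit from the commented-out block and stands in for whatever commit the reviewers agree on:

# Sketch only: pin FA to a known commit and still use the Triton install path.
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_SHA="b3ae4966b2567811880db10d9e040a775b99c7d7"
RUN git clone ${FA_REPO} \
    && cd flash-attention \
    && git checkout ${FA_SHA} \
    && git submodule update --init \
    && FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .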